#!/usr/bin/env python
# -*- coding:utf-8 -*-
from keras.layers import Dense, Dropout, Embedding, LSTM
from keras.models import Sequential

from const import EMBEDDING_DIM, MAX_SEQUENCE_LENGTH

# Public API of this module: `from <module> import *` exports only the model factory.
__all__ = ['load_lstm_model']


def load_lstm_model(nb_words, embedding_weights=None, embedding_dim=EMBEDDING_DIM,
                    max_sequence_length=MAX_SEQUENCE_LENGTH, num_classes=3):
    """Build and compile a stacked-LSTM sequence classifier.

    Architecture:
        Embedding -> LSTM(32) -> LSTM(32) -> Dropout(0.2)
        -> LSTM(256) -> Dropout(0.5) -> Dense(128, tanh) -> Dropout(0.2)
        -> Dense(num_classes, softmax)

    Args:
        nb_words: Vocabulary size, passed to the Embedding layer as
            ``input_dim``. The layer uses ``mask_zero=True``, so index 0 is
            treated as padding and masked out.
        embedding_weights: Optional pre-trained embedding matrix used to
            initialise the Embedding layer. Keras expects its shape to be
            ``(nb_words, embedding_dim)``. If ``None``, the layer keeps its
            default initialisation.
        embedding_dim: Size of each embedding vector (``output_dim``).
        max_sequence_length: Length of the padded input sequences.
        num_classes: Width of the final softmax layer, i.e. the number of
            output classes. Defaults to 3, the previously hard-coded value.

    Returns:
        A compiled ``keras.models.Sequential`` model. It uses categorical
        cross-entropy loss, the Adam optimizer and categorical accuracy.
    """
    model = Sequential()

    # Input layer: token ids -> dense vectors.
    embedding_kwargs = {}
    if embedding_weights is not None:
        # Forward `weights` only when a matrix is supplied; `weights=[None]`
        # is not a valid initial-weight list for the layer.
        embedding_kwargs['weights'] = [embedding_weights]
    model.add(Embedding(input_dim=nb_words, output_dim=embedding_dim, mask_zero=True,
                        input_length=max_sequence_length, **embedding_kwargs))

    # Stacked LSTM layers. The first two return full sequences so the next
    # LSTM receives one vector per time step; the last returns only its final state.
    model.add(LSTM(32, return_sequences=True, activation='tanh'))
    model.add(LSTM(32, return_sequences=True, activation='tanh'))
    model.add(Dropout(0.2))
    model.add(LSTM(256, activation='tanh'))
    model.add(Dropout(0.5))

    # Fully connected classifier head.
    model.add(Dense(128, activation='tanh'))
    model.add(Dropout(0.2))
    model.add(Dense(num_classes, activation='softmax'))

    # Training configuration.
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['categorical_accuracy'])

    return model
